[PyTorch] Add grouped linear op and experimental fusion for grouped MLP #2622
base: main
Conversation
Refactor fusion functions to remove index bookkeeping. Refactor fused ops to use consistent operation order. Signed-off-by: Tim Moon <tmoon@nvidia.com>
Test is too permissive since the test should still be failing. The weights are not properly interleaved yet. Signed-off-by: Tim Moon <tmoon@nvidia.com>
9 files reviewed, 4 comments
quantizer=fc2_input_quantizers[group_idx],
requires_grad=False,
with_gemm_swizzled_scales=True,
Incorrect grad-required flags
In ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8.fuser_forward, swiglu_ctx.input_requires_grad and swiglu_ctx.extra_input_requires_grad are set to True unconditionally (and input_requires_grad is set to requires_grad unconditionally). This will make ScaledSwiGLU.fuser_backward compute grad_input and grad_extra_input even when neither input_ nor scales requires grads, which violates autograd semantics and can raise errors (e.g., scales.detach() is passed into the fused kernel, but extra_input_requires_grad=True still forces a gradient).
These flags should be set based on the actual requirements:
- input_requires_grad = input_.requires_grad
- swiglu_ctx.extra_input_requires_grad = scales.requires_grad
- and for the FC weights, check each parameter's requires_grad (not just weight0).
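As a minimal sketch (assuming input_, scales, and per-group FC weight lists are in scope in fuser_forward, and using a plain namespace as a stand-in for the SwiGLU context), the flags could be derived like this:

import torch
from types import SimpleNamespace

# Hypothetical stand-ins for the objects handled by fuser_forward; the names
# mirror the comment above and are not the actual Transformer Engine objects.
input_ = torch.randn(4, 8, requires_grad=True)
scales = torch.randn(4, 1).detach()  # detached scales -> no gradient needed
fc1_weights = [torch.nn.Parameter(torch.randn(8, 8)) for _ in range(2)]
fc2_weights = [torch.nn.Parameter(torch.randn(8, 8)) for _ in range(2)]
swiglu_ctx = SimpleNamespace()

# Derive the flags from the tensors instead of hard-coding True:
input_requires_grad = input_.requires_grad
swiglu_ctx.input_requires_grad = input_.requires_grad
swiglu_ctx.extra_input_requires_grad = scales.requires_grad

# For the FC weights, consult every parameter rather than just weight0:
fc1_weight_requires_grad = any(w.requires_grad for w in fc1_weights)
fc2_weight_requires_grad = any(w.requires_grad for w in fc2_weights)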
# Return immediately if fused kernel is not supported
if not BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8.is_supported():
    return ops

# Check if recipe is supported
if recipe is None:
    return ops
if not recipe.mxfp8():
    return ops

# Scan through ops, fusing if possible
out = []
window, ops = ops[:3], ops[3:]
while len(window) == 3:

    # Check if window matches pattern
    matches_pattern = True
    if not (
        isinstance(window[0], GroupedLinear)
        and isinstance(window[1], ScaledSwiGLU)
        and isinstance(window[2], GroupedLinear)
    ):
        matches_pattern = False
    elif window[0].has_bias or window[2].has_bias:
        matches_pattern = False
    elif window[0].num_groups != window[2].num_groups:
        matches_pattern = False
    elif (
        window[0].in_features % 256 != 0
        or window[0].out_features % 256 != 0
        or window[2].in_features % 256 != 0
        or window[2].out_features % 256 != 0
    ):
        matches_pattern = False
    elif window[1].glu_interleave_size != 32:
        matches_pattern = False

    if matches_pattern:
        # Construct fused op if window matches pattern
        op = BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8(
            fc1=window[0],
Broken fusion window scan
Both fuse_backward_ops and fuse_forward_ops have a window/shift loop that can drop or reorder ops when the pattern doesn’t match. In the non-matching branch you do out.extend(window[:-2]); window = window[-2:] and then immediately do out.extend(window[:-3]) (which is a no-op for a 2-element window) before refilling. This causes the scan to advance by 1 op in some cases and by 2 in others, and it never emits window[-1] until the very end. For sequences like [A,B,C,D] where [A,B,C] doesn’t match but [B,C,D] would (or vice versa), this loop will not correctly consider all 3-op windows and can produce an incorrect fused op list.
This needs a standard sliding-window approach (advance by 1 when not matching; replace 3->1 when matching) to ensure no ops are skipped or duplicated.
I think your reading of the code is incorrect, although I would take suggestions for improving clarity. The loop condition guarantees the window size is 3 at the beginning of each iteration. We either fuse the window or eject the first op in the window, and then we refill back up to window size 3.
The only hint of an edge case I can see is if we perform a fusion, and that fused op can participate in further fusions. Then we might want to rewind the sliding window so that we reexamine the fused op in each window position. However, we know that the fused op is final, and we can safely advance the window past it.
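To make the intended control flow concrete, here is a small self-contained toy model of the scan described above (the match predicate and fused-op factory are placeholders, not the Transformer Engine classes): each iteration either fuses the full 3-op window or ejects its first op, then refills the window to size 3 from the remaining ops.

def fuse_ops(ops, matches_pattern, make_fused_op):
    # Toy sliding-window fusion: the invariant is that the window holds 3 ops
    # at the top of each iteration; leftovers are flushed after the loop.
    out = []
    window, ops = list(ops[:3]), list(ops[3:])
    while len(window) == 3:
        if matches_pattern(window):
            # The fused op is final, so advance the window past it.
            out.append(make_fused_op(window))
            window = []
        else:
            # Eject the first op and keep scanning.
            out.append(window[0])
            window = window[1:]
        # Refill the window back up to 3 ops.
        while len(window) < 3 and ops:
            window.append(ops.pop(0))
    out.extend(window)
    return out

# Example with plain Python values standing in for ops:
print(fuse_ops(
    [1, "swiglu", 2, "other", 3],
    matches_pattern=lambda w: isinstance(w[0], int)
    and isinstance(w[1], str)
    and isinstance(w[2], int),
    make_fused_op=tuple,
))  # [(1, 'swiglu', 2), 'other', 3]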
Review suggestion from @ptrendx. Signed-off-by: Tim Moon <tmoon@nvidia.com>
4 files reviewed, no comments
Description
This PR adds a grouped linear op, which can be used in the grouped MLP block of Mixture-of-Experts models. It also adds an experimental fused operation for a grouped MLP block, using a CuTe DSL kernel that computes an MXFP8 grouped GEMM and SwiGLU.
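For reference, here is a plain-PyTorch sketch of the semantics such a grouped MLP block computes (each expert applies its own FC1, SwiGLU, and FC2 to the tokens routed to it); the shapes, group sizes, and SwiGLU gating convention below are illustrative assumptions, not the new op's API:

import torch
import torch.nn.functional as F

num_groups, hidden, ffn = 4, 64, 128
group_sizes = [8, 16, 4, 12]               # tokens routed to each expert
x = torch.randn(sum(group_sizes), hidden)  # tokens already sorted by expert

# One set of weights per expert; FC1 is doubled for the gated (SwiGLU) activation.
fc1_w = [torch.randn(2 * ffn, hidden) for _ in range(num_groups)]
fc2_w = [torch.randn(hidden, ffn) for _ in range(num_groups)]

outs = []
for chunk, w1, w2 in zip(x.split(group_sizes), fc1_w, fc2_w):
    gate, up = (chunk @ w1.t()).chunk(2, dim=-1)  # SwiGLU: silu(gate) * up
    outs.append((F.silu(gate) * up) @ w2.t())
y = torch.cat(outs)                               # same token order as the input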
Type of change
Changes
Checklist: